{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sympy as sp\n",
    "import matplotlib.pyplot as plt\n",
    "from pyodesys.symbolic import PartiallySolvedSystem, ScaledSys, symmetricsys, get_logexp\n",
    "from pyodesys.plotting import plot_result as _plot_res\n",
    "from pycvodes import fpes\n",
    "from chempy import ReactionSystem, Substance\n",
    "from chempy.kinetics.ode import get_odesys\n",
    "sp.init_printing()\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsys = ReactionSystem.from_string(\n",
    "    \"\"\"\n",
    "    A -> B; 0.04\n",
    "    B + C -> A + C; 1e4\n",
    "    2 B -> B + C; 3e7\n",
    "    \"\"\",\n",
    "    substance_factory=lambda formula: Substance(formula, composition={1: 1}),\n",
    "    substances='ABC')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsys.composition_balance_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dep_scaling = 1e8\n",
    "orisys, extra = get_odesys(rsys, description='original')\n",
    "scaledsys = ScaledSys.from_other(orisys, dep_scaling=dep_scaling, description='scaled')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orisys.linear_invariants, orisys.names, orisys.latex_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indep_end = 1e18\n",
    "init_conc = {'A': 1, 'B': 0, 'C': 0}\n",
    "integrate_kw = dict(integrator='cvode', atol=1e-6, rtol=1e-6, nsteps=5000, record_rhs_xvals=True, record_jac_xvals=True)\n",
    "\n",
    "def integrate_systems(systems, **kwargs):\n",
    "    for k, v in integrate_kw.items():\n",
    "        if k not in kwargs:\n",
    "            kwargs[k] = v\n",
    "    return [odesys.integrate(indep_end, init_conc, record_order=True, record_fpe=True,\n",
    "                             return_on_error=True, **kwargs) for odesys in systems]\n",
    "\n",
    "def descr_str(info):\n",
    "    keys = 'n_steps nfev time_cpu success'.split()\n",
    "    if isinstance(info, dict):\n",
    "        nfo = info\n",
    "    else:\n",
    "        nfo = {}\n",
    "        for d in info:\n",
    "            for k in keys:\n",
    "                if k in nfo:\n",
    "                    if k == 'success':\n",
    "                        nfo[k] = nfo[k] and d[k]\n",
    "                    else:\n",
    "                        nfo[k] += d[k]\n",
    "                else:\n",
    "                    nfo[k] = d[k]\n",
    "    return ' (%d steps, %d rhs evals., %.5g s CPUt)' % tuple([nfo[k] for k in keys[:3]]) + (\n",
    "        '' if nfo['success'] else ' - failed!')\n",
    "\n",
    "def plot_result(description, res, ax=None, vline_post_proc=None, colors=('k', 'r', 'g'), xlim=None):\n",
    "    if ax is None:\n",
    "        ax = plt.subplot(1, 1, 1)\n",
    "    vk = 'steps rhs_xvals jac_xvals'.split()  # 'fe_underflow fe_overflow fe_invalid fe_divbyzero'\n",
    "    res.plot(xscale='log', yscale='log', ax=ax, c=colors,\n",
    "             info_vlines_kw=dict(vline_keys=vk, post_proc=vline_post_proc))\n",
    "    lines = ax.get_lines()\n",
    "    for idx, val in enumerate([1 - 0.04*1e-7, 0.04*1e-7]):\n",
    "        ax.plot(1e-7, val, marker='o', c=colors[idx])\n",
    "    for idx, val in enumerate([0.2083340149701255e-7, 0.8333360770334713e-13, 0.9999999791665050]):\n",
    "        ax.plot(1e11, val, marker='o', c=colors[idx])\n",
    "    ax.legend(loc='best')\n",
    "    ax.set_title(description + descr_str(res.info))\n",
    "    if xlim is not None:\n",
    "        ax.set_xlim(xlim)\n",
    "    \n",
    "def plot_results(systems, results, axes=None, **kwargs):\n",
    "    if axes is None:\n",
    "        _fig, axes = plt.subplots(2, 2, figsize=(14, 14))\n",
    "    for idx, (odesys, res) in enumerate(zip(systems, results)):\n",
    "        plot_result(odesys.description, res, ax=axes.flat[idx], **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_results([orisys, scaledsys], integrate_systems([orisys, scaledsys], first_step=1e-23),\n",
    "             axes=plt.subplots(1, 2, figsize=(14, 7))[1], xlim=(1e-23, indep_end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "psys = [PartiallySolvedSystem.from_linear_invariants(scaledsys, [k], description=k) for k in 'ABC']\n",
    "scaled_linear = [scaledsys] + psys\n",
    "[cs.dep for cs in psys]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Switch formulation using roots:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orisys['A']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rssys_A = PartiallySolvedSystem.from_linear_invariants(\n",
    "    scaledsys, ['A'], description='A root finding',\n",
    "    roots=[10*scaledsys['A'] - scaledsys['C']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_A = rssys_A.integrate(indep_end, init_conc, return_on_root=True, **integrate_kw)\n",
    "res_C = psys[2].integrate(indep_end - res_A.xout[-1], res_A.yout[-1, :], **integrate_kw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(16, 4))\n",
    "plt_kw = dict(xscale='log', yscale='log')\n",
    "res_A.plot(**plt_kw, ax=axes[0])\n",
    "res_C.plot(**plt_kw, ax=axes[1])\n",
    "_plot_res(np.concatenate((res_A.xout, res_C.xout[1:] + res_A.xout[-1])),\n",
    "          np.concatenate((res_A.yout, res_C.yout[1:])),\n",
    "          ax=axes[2], names='ABC', **plt_kw)\n",
    "for descr, ax, info in zip(['A', 'C', 'A/C'], axes, [res_A.info, res_C.info, (res_A.info, res_C.info)]):\n",
    "    ax.legend(loc='best', prop={'size': 11})\n",
    "    ax.set_title(descr + descr_str(info))\n",
    "    ax.set_xlim([1e-9, indep_end])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not switching:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "linresults = integrate_systems(scaled_linear)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_results(scaled_linear, linresults, xlim=(1e-9, indep_end))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_result('scaled A/C ', rssys_A.integrate(\n",
    "        indep_end, init_conc, return_on_root=True, **integrate_kw).extend_by_integration(\n",
    "            indep_end, odesys=scaled_linear[3], **integrate_kw), xlim=(1e-9, indep_end))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logarithmic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indep0 = 1e-26\n",
    "def symmetric_factory(exprs_process_cb=lambda exprs: [sp.powsimp(expr.expand(), force=True) for expr in exprs]):\n",
    "    return symmetricsys(get_logexp(1, 1e-20, b2=0), get_logexp(1, indep0, b2=0), \n",
    "                        exprs_process_cb=exprs_process_cb, check_transforms=False)\n",
    "LogLogSys_together = symmetric_factory()\n",
    "stlogsystems = [LogLogSys_together.from_other(ls, description='stl ' + ls.description) for ls in scaled_linear]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stlogresults = integrate_systems(stlogsystems, nsteps=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_results(stlogsystems, stlogresults, vline_post_proc=lambda x: np.abs(np.exp(x) - indep0), xlim=[indep0*10, indep_end])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stlogsystems[0].exprs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unscaled_linear = [orisys] + [PartiallySolvedSystem.from_linear_invariants(\n",
    "        orisys, [k], description=k) for k in 'ABC']\n",
    "utlogsystems = [LogLogSys_together.from_other(ls, description='utl ' + ls.description) for ls in unscaled_linear]\n",
    "utlogresults = integrate_systems(utlogsystems, nsteps=500)\n",
    "plot_results(utlogsystems, utlogresults, vline_post_proc=lambda x: np.abs(np.exp(x) - indep0), xlim=[indep0*10, indep_end])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "utlogsystems[0].exprs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LogLogSys_apart = symmetric_factory(None)\n",
    "salogsystems = [LogLogSys_apart.from_other(ls, description='sal ' + ls.description) for ls in scaled_linear]\n",
    "salogresults = integrate_systems(salogsystems, nsteps=500)\n",
    "plot_results(salogsystems, salogresults, vline_post_proc=lambda x: np.abs(np.exp(x) - indep0), xlim=[indep0*10, indep_end])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ualogsystems = [LogLogSys_apart.from_other(ls, description='ual ' + ls.description) for ls in unscaled_linear]\n",
    "ualogresults = integrate_systems(ualogsystems, nsteps=500)\n",
    "plot_results(ualogsystems, ualogresults, vline_post_proc=lambda x: np.abs(np.exp(x) - indep0), xlim=[indep0*10, indep_end])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rsalogsys_A = LogLogSys_apart.from_other(rssys_A, autonomous_interface=True)\n",
    "assert rsalogsys_A.autonomous_interface\n",
    "rsalogsys_A.roots, rsalogsys_A.exprs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_log_switch = rsalogsys_A.integrate(indep_end, init_conc, return_on_root=True, **integrate_kw)\n",
    "res_log_switch.extend_by_integration(indep_end, odesys=salogsystems[3], return_on_error=True,\n",
    "                                     autonomous=True, **integrate_kw)\n",
    "fig, ax = plt.subplots(1, 1, figsize=(14, 4))\n",
    "ax.axvline(res_log_switch.xout[res_log_switch.info['root_indices'][0]], ls='--')\n",
    "res_log_switch.plot(ax=ax, xscale='log', yscale='log')\n",
    "ax.legend(loc='best', prop={'size': 11})\n",
    "ax.set_title('log A/C ' + descr_str(res_log_switch.info))\n",
    "_ = ax.set_xlim([1e-9, indep_end])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=(16, 4))\n",
    "plot_result('log A/C', res_log_switch, ax, vline_post_proc=lambda x: np.abs(np.exp(x) - indep0))\n",
    "_ = ax.set_xlim([1e-20, 1e12])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
